// Copyright 2023 TikTok Pte. Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "nlohmann/json.hpp"

#include "dpca-psi/crypto/dp_sampling.h"
#include "dpca-psi/crypto/ecc_cipher.h"
#include "dpca-psi/crypto/ipcl_paillier.h"
#include "dpca-psi/crypto/prng.h"
#include "dpca-psi/network/io_base.h"

namespace privacy_go {
namespace dpca_psi {

using json = nlohmann::json;

class DPCardinalityPSI {
public:
    DPCardinalityPSI();

    // DPCardinalityPSI is neither copyable nor movable. The user-declared (deleted) copy constructor suppresses
    // the implicit move constructor, so an attempted move resolves to the deleted copy.
    DPCardinalityPSI(const DPCardinalityPSI& other) = delete;

    // DPCardinalityPSI is neither copy- nor move-assignable, for the same reason.
    DPCardinalityPSI& operator=(const DPCardinalityPSI& other) = delete;

    // Initializes parameters and variables according to the JSON configuration `params`. `net` is the network
    // channel used to talk to the other party.
    //   1. Generates multiple ECC encryptors and a Paillier encryptor, with secret keys.
    //   2. Exchanges Paillier public keys with the other party.
    // The JSON params are structured as shown below. The example is illustrative rather than literal JSON:
    // "a/b/c" lists alternative values, and NID_X9_62_prime256v1(415) names the OpenSSL curve NID whose integer
    // value is 415. A real configuration supplies a single number for each such field.
    /*
    {
        "common": {
            "address": "127.0.0.1",
            "remote_port": 30330,
            "local_port": 30331,
            "timeout": 90,
            "input_file": "example/data/sender_input_file.csv",
            "has_header": false,
            "output_file": "example/data/sender_output_file.csv",
            "ids_num": 3,
            "is_sender": true,
            "verbose": false
        },
        "paillier_params": {
            "paillier_n_len": 2048,
            "enable_djn": true,
            "apply_packing": true,
            "statistical_security_bits": 40
        },
        "ecc_params": {
            "curve_id": NID_X9_62_prime256v1(415)
        },
        "dp_params": {
            "epsilon": 2/4/6/8,
            "maximum_queries": 10,
            "use_precomputed_tau": true,
            "precomputed_tau": 722/954/1194/1440/1690,
            "input_dp": true,
            "has_zero_column": false,
            "zero_column_index": -1
        }
    }
    */
    void init(const json& params, std::shared_ptr<IOBase> net);

    // 1. Exchanges the number of rows and the number of feature columns per row.
    // 2. Samples dummy data and appends them to the original datasets, on both sender's and receiver's side.
    // 3. Generates random permutations of rows.
    // keys: one vector of key strings per row; features: one vector of integer features per row.
    void data_sampling(
            const std::vector<std::vector<std::string>>& keys, const std::vector<std::vector<std::uint64_t>>& features);

    // Performs intersection and stores secret shares in `shares` for both parties.
    // In detail, the dpca-psi workflow is:
    //   1. Shuffles and encrypts keys of every row on both parties' side. Exchanges keys with the other party.
    //   2. Reshuffles and doubly encrypts the exchanged keys. Sends back keys to the other party.
    //   3. Computes intersection on the first column and saves indices of the intersection.
    //   4. Iteratively repeats 1~3 for the rest of columns and saves the intersection's indices.
    //   5. Shuffles and encrypts features on both parties' side. Exchanges features with the other party.
    //   6. Generates additive shares of Paillier-encrypted features.
    //   7. Decrypts and converts additive shares in Z_n to additive shares in Z_{2^l}.
    void process(std::vector<std::vector<std::uint64_t>>& shares);

    // All owned state is held by RAII members, so the defaulted destructor suffices.
    ~DPCardinalityPSI() = default;

private:
    // Checks the validity and consistency of json params of both parties.
    void check_params();

    // Permutes the keys with the pattern generated by itself. Encrypts them with ECC encryptors.
    // Stores keys encrypted by the first ECC key in encrypted_keys.
    void shuffle_and_encrypt_keys_round_one(std::vector<std::vector<ByteVector>>& encrypted_keys);

    // Permutes the exchanged keys with the pattern generated by the other party. Doubly encrypts them with ECC
    // encryptors. Stores the first column's reshuffled encrypted keys in reshuffled_encrypted_keys.
    void reshuffle_and_encrypt_exchanged_keys_round_one(std::vector<ByteVector>& reshuffled_encrypted_keys);

    // Iteratively repeats the matching procedure for the i-th column, where i is in [2, key_size_].
    //   1. Removes the rows that have been matched in the (i-1)-th matching.
    //   2. Shuffles and encrypts the i-th column's keys on both parties' side. Exchanges the i-th column's keys with
    //      the other party.
    //   3. Doubly encrypts the i-th column's exchanged keys and sends them back to the other party.
    //   4. Computes intersection on the i-th column and saves the intersection's indices.
    // Returns the size of the final intersection.
    std::size_t repeatedly_match(std::size_t intersection_round_one);

    // Computes intersection on the 1st column and saves the intersection's indices.
    // Returns the size of the intersection on the 1st column.
    std::size_t calculate_intersection_round_one(
            const std::vector<ByteVector>& encrypted_keys, const std::vector<ByteVector>& exchanged_keys);

    // Computes intersection on the i-th column and saves the intersection's indices.
    // Returns the size of the intersection on the i-th column.
    std::size_t calculate_intersection_round_i(const std::vector<ByteVector>& encrypted_keys,
            const std::vector<ByteVector>& exchanged_keys, const std::vector<std::size_t>& mapping);

    // Permutes the features with the pattern generated by itself. Encrypts them with a Paillier encryptor.
    // Adopts Paillier's ciphertext packing to reduce communication and computation.
    // Stores encrypted features in encrypted_features.
    void shuffle_and_encrypt_features(std::vector<std::vector<ByteVector>>& encrypted_features);

    // Filters out the intersection's features from all encrypted features according to the intersecting keys.
    // Stores filtered features in intersection_features.
    void filter_intersection_features(const std::vector<std::vector<ByteVector>>& encrypted_features,
            std::size_t intersection_size, std::vector<std::vector<ByteVector>>& intersection_features);

    // Generates additive shares of Paillier-encrypted features.
    // NOTE(review): random_r holds BigNumber values while encrypted_features is taken by non-const reference;
    // presumably the random masks are kept in random_r and the masked ciphertexts are written back into
    // encrypted_features — confirm against the implementation.
    void generate_additive_shares(IpclPaillier& paillier, std::vector<std::vector<ByteVector>>& encrypted_features,
            std::vector<std::vector<BigNumber>>& random_r);

    // Decrypts and converts additive shares in Z_n to additive shares in Z_{2^l}.
    void decrypt_and_reveal_shares(const std::vector<std::vector<ByteVector>>& encrypted_shares,
            const std::vector<std::vector<BigNumber>>& random_r, std::size_t intersection_size,
            std::vector<std::vector<std::uint64_t>>& shares);

    // Exchanges encrypted keys or doubly encrypted keys with the other party.
    void exchange_encrypted_keys(const std::vector<std::vector<ByteVector>>& encrypted_keys,
            std::size_t received_keys_size, std::size_t received_data_size,
            std::vector<std::vector<ByteVector>>& received_keys, std::size_t point_len);

    // Exchanges a single column's encrypted keys or doubly encrypted keys with the other party.
    void exchange_single_encrypted_keys(const std::vector<ByteVector>& encrypted_keys, std::size_t received_data_size,
            std::vector<ByteVector>& received_keys, std::size_t point_len);

    // Exchanges encrypted features or encrypted additive shares with the other party.
    void exchange_encrypted_features(const std::vector<std::vector<ByteVector>>& encrypted_features,
            std::size_t self_paillier_len, std::size_t remote_paillier_len, std::size_t received_features_size,
            std::size_t received_data_size, std::vector<std::vector<ByteVector>>& received_features);

    // Resets data at the end of process function.
    void reset_data();

    // Role of this party (the "is_sender" field of the "common" params section).
    bool is_sender_ = false;

    // NOTE(review): default-initialized to a JSON string value "" rather than an empty object; presumably
    // replaced by the configuration in init() — confirm.
    json params_ = "";
    bool verbose_ = false;

    std::unique_ptr<EccCipher> ecc_cipher_ = nullptr;
    std::size_t num_threads_ = 0;

    // Paillier contexts for each role.
    IpclPaillier sender_paillier_{};
    IpclPaillier receiver_paillier_{};
    bool apply_packing_ = false;
    std::size_t statistical_security_bits_ = 0;
    std::size_t slot_bits_ = 0;

    // Channel to the other party.
    std::shared_ptr<IOBase> io_ = nullptr;

    // Dataset dimensions of both parties.
    std::size_t key_size_ = 0;
    std::size_t sender_data_size_ = 0;
    std::size_t sender_feature_size_ = 0;
    std::size_t receiver_data_size_ = 0;
    std::size_t receiver_feature_size_ = 0;

    std::vector<std::vector<std::string>> plaintext_keys_{};
    std::vector<std::vector<std::uint64_t>> plaintext_features_{};

    // Row permutations of each party.
    std::vector<std::size_t> sender_permutation_{};
    std::vector<std::size_t> receiver_permutation_{};

    std::vector<std::vector<ByteVector>> exchanged_keys_{};

    // NOTE(review): the meaning of the bool/ByteVector pair is not visible from this header — document once
    // confirmed against the implementation.
    std::vector<std::pair<bool, ByteVector>> intersection_indices_{};
};

}  // namespace dpca_psi
}  // namespace privacy_go
